{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import pandas as pd\n",
    "import torch\n",
    "from torch import nn\n",
    "import pickle\n",
    "from tqdm import tqdm\n",
    "import re\n",
    "import json\n",
    "from torch.utils.data import DataLoader, Dataset\n",
    "import torch.optim as optim\n",
    "import torch.nn.functional as F\n",
    "from collections import Counter"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "from model import *\n",
    "from data import *\n",
    "from utils import *"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "os. environ[\"CUDA_VISIBLE_DEVICES\"] = \"1\"\n",
    "device = torch.device('cuda:0')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Train"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "- **set params**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "lr = 0.0001\n",
    "batch_size=128"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "- **load train data**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "news_vert = np.load('../../data2/news_vert.npy')\n",
    "news_title = np.load('../../data2/news_title.npy')\n",
    "news_body = np.load('../../data2/news_body.npy')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "with open('../../data2/TrainUsers.pkl', 'rb') as f:\n",
    "    TrainUsers = pickle.load(f)\n",
    "with open('../../data2/TrainSamples.pkl', 'rb') as f:\n",
    "    TrainSamples = pickle.load(f)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "- **load model**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "with open('../../data2/dict.pkl', 'rb') as f:\n",
    "    _,category_dict,word_dict = pickle.load(f)\n",
    "with open('../../data2/news.pkl', 'rb') as f:\n",
    "    news = pickle.load(f)\n",
    "embedding_matrix = np.load('../../data2/embedding_matrix.npy')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "torch.cuda.empty_cache()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "model = NRMS(embedding_matrix)\n",
    "model = model.to(device)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "- **begin training**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "def acc(y_true, y_hat):\n",
    "    y_hat = torch.argmax(y_hat, dim=-1)\n",
    "    tot = y_true.shape[0]\n",
    "    hit = torch.sum(y_true == y_hat)\n",
    "    return hit.data.float() * 1.0 / tot"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "optimizer = optim.Adam(model.parameters(), lr=0.0001) #lr = 0.0001"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "ed: 762880, train_loss: 0.58991, acc: 0.67979: 100%|██████████| 5963/5963 [24:23<00:00,  4.07it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1 tensor(0.5898, device='cuda:0')\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "ed: 762880, train_loss: 0.55482, acc: 0.71245: 100%|██████████| 5963/5963 [24:57<00:00,  3.98it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2 tensor(0.5547, device='cuda:0')\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "ed: 762880, train_loss: 0.54282, acc: 0.72187: 100%|██████████| 5963/5963 [24:39<00:00,  4.03it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "3 tensor(0.5427, device='cuda:0')\n"
     ]
    }
   ],
   "source": [
    "min_train_loss = 100.0\n",
    "for ep in range(1,4):\n",
    "    loss = 0.0\n",
    "    accuary = 0.0\n",
    "    cnt = 1\n",
    "    dset = TrainDataset(TrainUsers, TrainSamples, news_title, news_vert, news_body)\n",
    "    data_loader = DataLoader(dset, batch_size=128, collate_fn=collate_fn, shuffle=True)\n",
    "    tqdm_util = tqdm(data_loader)\n",
    "    for user_feature, news_feature, label in tqdm_util: \n",
    "        user_feature = [i.to(device) for i in user_feature]\n",
    "        news_feature = [i.to(device) for i in news_feature]\n",
    "        label = label.to(device)\n",
    "        bz_loss, y_hat = model(user_feature, news_feature, label)\n",
    "        loss += bz_loss.data.float()\n",
    "        accuary += acc(label, y_hat)\n",
    "\n",
    "        optimizer.zero_grad()\n",
    "        bz_loss.backward()\n",
    "        optimizer.step()\n",
    "\n",
    "        if cnt % 10 == 0:\n",
    "            tqdm_util.set_description('ed: {}, train_loss: {:.5f}, acc: {:.5f}'.format(cnt * batch_size, loss.data / cnt, accuary / cnt))\n",
    "        cnt += 1\n",
    "    loss /= cnt\n",
    "    print(ep, loss)\n",
    "    torch.save(model.state_dict(), '../../runs/userencoder/NAML-{}.pkl'.format(ep))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Test"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "- **load test data**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "with open('../../data2/ValidUsers.pkl', 'rb') as f:\n",
    "    ValidUsers = pickle.load(f)\n",
    "with open('../../data2/ValidSamples.pkl', 'rb') as f:\n",
    "    ValidSamples = pickle.load(f)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 223/223 [00:02<00:00, 88.28it/s]\n",
      "100%|██████████| 782/782 [00:02<00:00, 381.92it/s]\n",
      "100%|██████████| 100000/100000 [01:29<00:00, 1117.10it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0 0.0\n",
      "1\n",
      "AUC\t MRR\t nDCG5\t nDCG10\t CTR1\t CTR10\t\n",
      "(0.6407997589410694, 0.23081682854624697, 0.2535878624744038, 0.33227711742274696, 0.1454, 0.11970133333333333)\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 223/223 [00:02<00:00, 84.54it/s]\n",
      "100%|██████████| 782/782 [00:02<00:00, 376.80it/s]\n",
      "100%|██████████| 100000/100000 [01:30<00:00, 1098.98it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0 0.0\n",
      "2\n",
      "AUC\t MRR\t nDCG5\t nDCG10\t CTR1\t CTR10\t\n",
      "(0.6460787629807131, 0.23452229193374297, 0.2586624248644371, 0.3377814121898208, 0.14922, 0.12179533333333331)\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 223/223 [00:02<00:00, 86.55it/s]\n",
      "100%|██████████| 782/782 [00:02<00:00, 366.82it/s]\n",
      "100%|██████████| 100000/100000 [01:29<00:00, 1118.09it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0 0.0\n",
      "3\n",
      "AUC\t MRR\t nDCG5\t nDCG10\t CTR1\t CTR10\t\n",
      "(0.6464411767491163, 0.23536873814293183, 0.25995044819703567, 0.33848005066956344, 0.14956, 0.12227333333333334)\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "for ep in range(1,4):\n",
    "    model = NRMS(embedding_matrix)\n",
    "    model = model.to(device)\n",
    "    model.load_state_dict(torch.load('../../runs/userencoder/NAML-{}.pkl'.format(ep)))\n",
    "    model.eval() \n",
    "    # save new embedding matrix\n",
    "    np.save('../../data2/embedding_matrix{}.npy'.format(ep), model.embed.weight.data.cpu().numpy())\n",
    "    \n",
    "\n",
    "    n_dset = news_dataset(news_title, news_vert, news_body)\n",
    "    news_data_loader = DataLoader(n_dset, batch_size=512, collate_fn=news_collate_fn, shuffle=False)\n",
    "\n",
    "      \n",
    "    news_scoring = []\n",
    "    torch.cuda.empty_cache()\n",
    "    with torch.no_grad():\n",
    "        for news_feature in tqdm(news_data_loader): \n",
    "            news_feature = [i.to(device) for i in news_feature]\n",
    "            news_vec = model.news_encoder(news_feature)\n",
    "            news_vec = news_vec.to(torch.device(\"cpu\")).detach().numpy()\n",
    "            news_scoring.extend(news_vec)\n",
    "    news_scoring = np.array(news_scoring)\n",
    "    np.save('../../data2/news_scoring{}.npy'.format(ep), news_scoring)\n",
    "    \n",
    "    u_dset = UserDataset(news_scoring, ValidUsers)\n",
    "    user_data_loader = DataLoader(u_dset, batch_size=128, shuffle=False)\n",
    "\n",
    "    user_scoring = []\n",
    "    with torch.no_grad():\n",
    "        for user_feature in tqdm(user_data_loader): \n",
    "            user_feature = user_feature.to(device)\n",
    "            user_vec = model.user_encoder(user_feature)\n",
    "            user_vec = user_vec.to(torch.device(\"cpu\")).detach().numpy()\n",
    "            user_scoring.extend(user_vec)\n",
    "    user_scoring = np.array(user_scoring)\n",
    "    np.save('../../data2/global_user_embed{}.npy'.format(ep),np.mean(user_scoring,axis=0))\n",
    "    g = evaluate(user_scoring,news_scoring, ValidSamples)\n",
    "    print(ep)\n",
    "    print('AUC\\t', 'MRR\\t', 'nDCG5\\t', 'nDCG10\\t','CTR1\\t','CTR10\\t')\n",
    "    print(g)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3.7.3 64-bit ('py3': conda)",
   "language": "python",
   "name": "python373jvsc74a57bd0f1ea14d0bd9c7a0f4f2d255c0662c7e1119328c505d1c17d9d8f159415dfcf69"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
